
# This class holds the confusion matrix information.
# It is designed to be updated incrementally, as results are obtained
# from the classifier model.
#
# At any point, statistics may be obtained by calling the relevant methods.
#
# A two-class example is:
# 
#     Classified      Classified    | 
#     Positive        Negative      | Actual
#     ------------------------------+------------
#         a               b         | Positive
#         c               d         | Negative
#
# Statistical methods will be described with reference to this example.
# 
class ConfusionMatrix
  # Creates a new, empty instance of a confusion matrix.
  #
  #  cm = ConfusionMatrix.new(:pos, :neg)
  #  cm.labels # => [:pos, :neg]
  #
  # @param labels [<String, Symbol>, ...] if provided, makes the matrix
  #        use the first label as a default label, and also check
  #        all operations use one of the pre-defined labels.
  # @raise [ArgumentError] if there are not at least two unique labels, when provided.
  def initialize(*labels)
    @labels = labels.uniq
    if @labels.size == 1
      raise ArgumentError, "If labels are provided, there must be at least two."
    end

    # Preset every (actual, predicted) cell to zero, so that pre-defined
    # labels appear in the matrix before any results are added.
    @matrix = {}
    @labels.each do |actual|
      @matrix[actual] = @labels.each_with_object({}) { |predicted, row| row[predicted] = 0 }
    end
  end

  # Returns a list of labels used in the matrix.
  #
  #  cm = ConfusionMatrix.new
  #  cm.add_for(:pos, :neg)
  #  cm.labels # => [:neg, :pos]
  #
  # If labels were pre-defined, they are returned in their given order.
  # Otherwise the labels seen so far (as actual or predicted) are returned
  # sorted, so they must be mutually comparable with +<=>+.
  #
  # @return [Array] labels used in the matrix.
  def labels
    return @labels if @labels.size >= 2

    @matrix.flat_map { |actual, predictions| [actual, *predictions.keys] }.uniq.sort
  end

  # Return the count for (actual,prediction) pair.
  #
  #  cm = ConfusionMatrix.new
  #  cm.add_for(:pos, :neg)
  #  cm.count_for(:pos, :neg) # => 1
  #
  # @param actual [String, Symbol] is actual class of the instance,
  #        which we expect the classifier to predict
  # @param prediction [String, Symbol] is the predicted class of the instance,
  #        as output from the classifier
  # @return [Integer] number of observations of (actual, prediction) pair
  # @raise [ArgumentError] if +actual+ or +prediction+ are not one of any
  #        pre-defined labels in matrix
  def count_for(actual, prediction)
    validate_label actual, prediction
    @matrix.fetch(actual, {}).fetch(prediction, 0)
  end

  # Adds one result to the matrix for a given (actual, prediction) pair of labels.
  # If the matrix was given a pre-defined list of labels on construction, then
  # these given labels must be from the pre-defined list.
  # If no pre-defined list of labels was used in constructing the matrix, then
  # labels will be added to matrix.
  #
  # Class labels may be any hashable value, though ideally they are strings or symbols.
  #
  # @param actual [String, Symbol] is actual class of the instance,
  #        which we expect the classifier to predict
  # @param prediction [String, Symbol] is the predicted class of the instance,
  #        as output from the classifier
  # @param n [Integer] number of observations to add
  # @return [Integer] the updated count for the (actual, prediction) pair
  # @raise [ArgumentError] if +n+ is not a positive Integer
  # @raise [ArgumentError] if +actual+ or +prediction+ are not one of any
  #        pre-defined labels in matrix
  def add_for(actual, prediction, n = 1)
    validate_label actual, prediction
    # Validate n before touching the matrix, so a rejected call leaves the
    # matrix (and hence #labels) unchanged.
    unless n.is_a?(Integer) && n.positive?
      raise ArgumentError, "add_for requires n to be a positive Integer, but got #{n}"
    end

    predictions = (@matrix[actual] ||= {})
    predictions[prediction] = predictions.fetch(prediction, 0) + n
  end

  # Returns the number of instances of the given class label which
  # are incorrectly classified.
  #
  #   false_negative(:positive) = b
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] value of false negative
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_negative(label = @labels.first)
    validate_label label
    @matrix.fetch(label, {}).sum { |predicted, count| predicted == label ? 0 : count }
  end

  # Returns the number of instances incorrectly classified with the given
  # class label.
  #
  #   false_positive(:positive) = c
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] value of false positive
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_positive(label = @labels.first)
    validate_label label
    @matrix.sum { |actual, predictions| actual == label ? 0 : predictions.fetch(label, 0) }
  end

  # The false rate for a given class label is the proportion of instances
  # incorrectly classified as that label, out of all those instances
  # not originally of that label.
  #
  #   false_rate(:positive) = c/(c+d)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of false rate
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def false_rate(label = @labels.first)
    validate_label label
    fp = false_positive(label)
    tn = true_negative(label)

    divide(fp, fp + tn)
  end

  # The F-measure for a given label is the harmonic mean of the precision
  # and recall for that label.
  #
  #   F = 2*(precision*recall)/(precision+recall)
  #
  # Returns 0.0 when precision and recall are both zero, rather than NaN.
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of F-measure
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def f_measure(label = @labels.first)
    validate_label label
    p = precision(label)
    r = recall(label)

    divide(2 * p * r, p + r)
  end

  # The geometric mean is the nth-root of the product of the true_rate for
  # each actual label in the matrix.
  #
  #   a1 = a/(a+b)
  #   a2 = d/(c+d)
  #   geometric_mean = Math.sqrt(a1*a2)
  #
  # Returns 0.0 for an empty matrix, matching the other statistics.
  #
  # @return [Float] value of geometric mean
  def geometric_mean
    return 0.0 if @matrix.empty?

    product = @matrix.keys.inject(1) { |acc, label| acc * true_rate(label) }
    product**(1.0 / @matrix.size)
  end

  # The Kappa statistic compares the observed accuracy with an expected
  # accuracy.
  #
  #   n = a+b+c+d
  #   observed = (a+d)/n
  #   expected = ((a+b)(a+c) + (c+d)(b+d))/n^2
  #   kappa = (observed - expected)/(1 - expected)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of Cohen's Kappa Statistic
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def kappa(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)
    total = tp + fn + fp + tn

    total_accuracy = divide(tp + tn, total)
    random_accuracy = divide((tn + fp) * (tn + fn) + (fn + tp) * (fp + tp), total * total)

    divide(total_accuracy - random_accuracy, 1 - random_accuracy)
  end

  # Matthews Correlation Coefficient is a measure of the quality of binary
  # classifications.
  #
  #   mathews_correlation(:positive) = (a*d - c*b) / sqrt((a+c)(a+b)(d+c)(d+b))
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of Matthews Correlation Coefficient
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def matthews_correlation(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)
    fp = false_positive(label)
    tn = true_negative(label)

    divide(tp * tn - fp * fn, Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
  end

  # The overall accuracy is the proportion of instances which are
  # correctly labelled.
  #
  #   overall_accuracy = (a+d)/(a+b+c+d)
  #
  # @return [Float] value of overall accuracy
  def overall_accuracy
    correct = @matrix.keys.sum { |label| true_positive(label) }
    divide(correct, total)
  end

  # The precision for a given class label is the proportion of instances
  # classified as that class which are correct.
  #
  #   precision(:positive) = a/(a+c)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of precision
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def precision(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fp = false_positive(label)

    divide(tp, tp + fp)
  end

  # The prevalence for a given class label is the proportion of all
  # instances whose actual class is that label.
  #
  #   prevalence(:positive) = (a+b)/(a+b+c+d)
  #
  # The denominator is the total of the whole matrix, so in the multi-class
  # case instances confused between two other classes are also counted.
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of prevalence
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def prevalence(label = @labels.first)
    validate_label label
    divide(true_positive(label) + false_negative(label), total)
  end

  # The recall is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def recall(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # Sensitivity is another name for the true rate.
  #
  # @see true_rate
  # @param (see #true_rate)
  # @return (see #true_rate)
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def sensitivity(label = @labels.first)
    validate_label label
    true_rate(label)
  end

  # The specificity for a given class label is 1 - false_rate(label)
  #
  # In two-class case, specificity = 1 - false_positive_rate
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] value of specificity
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def specificity(label = @labels.first)
    validate_label label
    1 - false_rate(label)
  end

  # Returns the table in a string format, representing the entries as a
  # printable table.
  #
  # @return [String] representation as a printable table.
  def to_s
    ls = labels
    header = ls.map { |l| "#{l} " }.join
    width = [header.size, "Predicted ".size].max

    rows = [
      "#{'Predicted '.ljust(width)}|",
      "#{header.ljust(width)}| Actual",
      "#{'-' * width}+-------"
    ]
    ls.each do |actual|
      # Right-align each count under its column heading; the column width is
      # the printed length of the predicted label (works for non-string labels).
      counts = ls.map { |predicted| "#{count_for(actual, predicted).to_s.rjust(predicted.to_s.size)} " }.join
      rows << "#{counts.ljust(width)}| #{actual}"
    end

    rows.map { |row| "#{row}\n" }.join
  end

  # Returns the total number of instances referenced in the matrix.
  #
  #   total = a+b+c+d
  #
  # @return [Integer] total number of instances referenced in the matrix.
  def total
    @matrix.values.sum { |predictions| predictions.values.sum }
  end

  # Returns the number of instances NOT of the given class label which
  # are correctly classified.
  #
  #   true_negative(:positive) = d
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of instances not of given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_negative(label = @labels.first)
    validate_label label
    @matrix.sum { |actual, predictions| actual == label ? 0 : predictions.fetch(actual, 0) }
  end

  # Returns the number of instances of the given class label which are
  # correctly classified.
  #
  #   true_positive(:positive) = a
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Integer] number of instances of given label which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_positive(label = @labels.first)
    validate_label label
    @matrix.fetch(label, {}).fetch(label, 0)
  end

  # The true rate for a given class label is the proportion of instances of
  # that class which are correctly classified.
  #
  #   true_rate(:positive) = a/(a+b)
  #
  # @param label [String, Symbol] of class to use, defaults to first of any pre-defined labels in matrix
  # @return [Float] proportion of instances which are correctly classified
  # @raise [ArgumentError] if +label+ is not one of any pre-defined labels in matrix
  def true_rate(label = @labels.first)
    validate_label label
    tp = true_positive(label)
    fn = false_negative(label)

    divide(tp, tp + fn)
  end

  private

  # A form of "safe divide".
  # Checks if divisor is zero, and returns 0.0 if so.
  # This avoids a run-time error.
  # Also, ensures floating point division is done.
  def divide(x, y)
    if y.zero?
      0.0
    else
      x.to_f / y
    end
  end

  # Checks that every given label is in @labels. This check is skipped when
  # no labels were pre-defined (@labels is empty).
  # Falsy labels such as +false+ are accepted when they were pre-defined.
  # Raises ArgumentError if any label is not pre-defined.
  def validate_label(*candidates)
    return true if @labels.empty?

    candidates.each do |label|
      unless @labels.include?(label)
        raise ArgumentError, "Given label (#{label}) is not in predefined list (#{@labels.join(',')})"
      end
    end
  end
end


